{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.1 决策树模型的基本原理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.1.3 决策树模型的代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "决策树模型既可以做分类分析（即预测分类变量值），也可以做回归分析（即预测连续变量值），分别对应的模型为分类决策树模型（DecisionTreeClassifier）及回归决策树模型（DecisionTreeRegressor）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**1.分类决策树模型（DecisionTreeClassifier）**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 0, 0, 1, 1]\n",
    "\n",
    "model = DecisionTreeClassifier(random_state=0)\n",
    "model.fit(X, y)\n",
    "\n",
    "print(model.predict([[5, 5]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果要同时预测多个数据，则可以写成如下形式："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0 1]\n"
     ]
    }
   ],
   "source": [
    "print(model.predict([[5, 5], [7, 7], [9, 9]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**补充知识点：决策树可视化（供感兴趣的读者参考）**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过graphviz插件进行决策树可视化，graphviz插件的使用教程可以查看该文档:https://shimo.im/docs/Dcgw8H6WxgWrc8hq/ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"262pt\" height=\"314pt\"\r\n",
       " viewBox=\"0.00 0.00 262.00 314.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 310)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-310 258,-310 258,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"199,-306 108,-306 108,-223 199,-223 199,-306\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-290.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 7.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-275.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.48</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-260.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 3]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"145,-187 54,-187 54,-104 145,-104 145,-187\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-171.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 2.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-156.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.444</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 1]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 0</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M134.765,-222.907C130.786,-214.286 126.542,-205.09 122.427,-196.175\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"125.57,-194.634 118.202,-187.021 119.215,-197.567 125.57,-194.634\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"109.559\" y=\"-206.779\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"254,-179.5 163,-179.5 163,-111.5 254,-111.5 254,-179.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-164.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-149.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 2]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>0&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M172.582,-222.907C177.769,-211.873 183.399,-199.898 188.628,-188.773\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"191.822,-190.206 192.909,-179.667 185.487,-187.228 191.822,-190.206\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"201.386\" y=\"-199.486\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"91,-68 0,-68 0,-0 91,-0 91,-68\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-52.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M79.3924,-103.726C75.1643,-95.1527 70.6946,-86.0891 66.437,-77.4555\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"69.4838,-75.7203 61.9217,-68.2996 63.2056,-78.8164 69.4838,-75.7203\"/>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"200,-68 109,-68 109,-0 200,-0 200,-68\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-52.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 0]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 0</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>1&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M119.98,-103.726C124.286,-95.1527 128.839,-86.0891 133.175,-77.4555\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"136.413,-78.8067 137.774,-68.2996 130.158,-75.6647 136.413,-78.8067\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x13c3d0264a8>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1.如果不用显示中文，那么通过如下代码即可。安装graphviz稍微有些麻烦，可以参考上面的参考链接。\n",
    "from sklearn.tree import export_graphviz\n",
    "import graphviz\n",
    "\n",
    "import os  # 以下两行是环境变量配置，运行一次即可，相关知识点可以参考5.2.3节\n",
    "os.environ['PATH'] = os.pathsep + r'C:\\Program Files (x86)\\Graphviz2.38\\bin'\n",
    "\n",
    "dot_data = export_graphviz(model, out_file=None, class_names=['0', '1'])\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph  # 通过graph.render('决策树可视化')可在代码所在文件夹生成决策树可视化PDF文件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**补充知识点：random_state参数的作用解释**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"262pt\" height=\"314pt\"\r\n",
       " viewBox=\"0.00 0.00 262.00 314.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 310)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-310 258,-310 258,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"199,-306 108,-306 108,-223 199,-223 199,-306\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-290.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 6.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-275.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.48</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-260.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 3]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"145,-187 54,-187 54,-104 145,-104 145,-187\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-171.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 2.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-156.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.444</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 1]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 0</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M134.765,-222.907C130.786,-214.286 126.542,-205.09 122.427,-196.175\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"125.57,-194.634 118.202,-187.021 119.215,-197.567 125.57,-194.634\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"109.559\" y=\"-206.779\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"254,-179.5 163,-179.5 163,-111.5 254,-111.5 254,-179.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-164.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-149.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 2]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>0&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M172.582,-222.907C177.769,-211.873 183.399,-199.898 188.628,-188.773\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"191.822,-190.206 192.909,-179.667 185.487,-187.228 191.822,-190.206\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"201.386\" y=\"-199.486\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"91,-68 0,-68 0,-0 91,-0 91,-68\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-52.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M79.3924,-103.726C75.1643,-95.1527 70.6946,-86.0891 66.437,-77.4555\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"69.4838,-75.7203 61.9217,-68.2996 63.2056,-78.8164 69.4838,-75.7203\"/>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"200,-68 109,-68 109,-0 200,-0 200,-68\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-52.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 0]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 0</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>1&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M119.98,-103.726C124.286,-95.1527 128.839,-86.0891 133.175,-77.4555\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"136.413,-78.8067 137.774,-68.2996 130.158,-75.6647 136.413,-78.8067\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x13c4c61dac8>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 0, 0, 1, 1]\n",
    "\n",
    "model = DecisionTreeClassifier()  # 不设置random_state参数\n",
    "model.fit(X, y)\n",
    "\n",
    "# 生成可视化结果\n",
    "dot_data = export_graphviz(model, out_file=None, class_names=['0', '1'])\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"262pt\" height=\"314pt\"\r\n",
       " viewBox=\"0.00 0.00 262.00 314.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 310)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-310 258,-310 258,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"199,-306 108,-306 108,-223 199,-223 199,-306\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-290.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 7.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-275.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.48</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-260.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 3]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"153.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"145,-187 54,-187 54,-104 145,-104 145,-187\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-171.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 3.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-156.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.444</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 1]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"99.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 0</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M134.765,-222.907C130.786,-214.286 126.542,-205.09 122.427,-196.175\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"125.57,-194.634 118.202,-187.021 119.215,-197.567 125.57,-194.634\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"109.559\" y=\"-206.779\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"254,-179.5 163,-179.5 163,-111.5 254,-111.5 254,-179.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-164.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-149.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 2]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"208.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>0&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M172.582,-222.907C177.769,-211.873 183.399,-199.898 188.628,-188.773\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"191.822,-190.206 192.909,-179.667 185.487,-187.228 191.822,-190.206\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"201.386\" y=\"-199.486\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"91,-68 0,-68 0,-0 91,-0 91,-68\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-52.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [0, 1]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"45.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 1</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M79.3924,-103.726C75.1643,-95.1527 70.6946,-86.0891 66.437,-77.4555\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"69.4838,-75.7203 61.9217,-68.2996 63.2056,-78.8164 69.4838,-75.7203\"/>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"200,-68 109,-68 109,-0 200,-0 200,-68\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-52.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">gini = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = [2, 0]</text>\r\n",
       "<text text-anchor=\"middle\" x=\"154.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">class = 0</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>1&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M119.98,-103.726C124.286,-95.1527 128.839,-86.0891 133.175,-77.4555\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"136.413,-78.8067 137.774,-68.2996 130.158,-75.6647 136.413,-78.8067\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x13c4c6452e8>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 0, 0, 1, 1]\n",
    "\n",
    "model = DecisionTreeClassifier()  # 不设置random_state参数\n",
    "model.fit(X, y)\n",
    "\n",
    "# 生成可视化结果\n",
    "dot_data = export_graphviz(model, out_file=None, class_names=['0', '1'])\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到它们的节点划分方式是不同的，这样会导致同样的数据的预测结果有所不同，例如数据[7, 7]在左边的决策树中会被预测为类别0，而在右边的决策树中会被预测为类别1。\n",
    "\n",
    "此时有的读者就会有疑问了，为什么模型训练后会产生两颗不同的树呢，哪棵树是正确的呢？其实两颗树都是正确的，出现这种情况的原因，是因为根据“X[1]<=7”或者“X[0]<=6”进行节点划分时产生的基尼系数下降是一样的（都是0.48 - (0.6*0.444 + 0.4*0) = 0.2136），所以无论以哪种形式进行节点划分都是合理的。产生这一现象的原因大程度是因为数据量较少所以容易产生不同划分方式产生的基尼系数下降是一样的情况，当数据量较大时出现该现象的几率则较小。\n",
    "\n",
    "总的来说，对于同一个模型，不同的划分方式可能会导致最后的预测结果会有所不同，设置random_state参数（可以设置成0，也可以设置成1或123等任意数字）则能保证每次的划分方式都是一致的，使得每次运行的结果相同。这个概念对于初学者来说还是挺重要的，因为初学者往往会发现怎么每次运行同一个模型出来的结果都不一样，从而一头雾水，如果出现这种情况，那么设置一下random_state参数即可。"
   ]
  },
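  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tie between the two root splits described above can be checked by hand. The next cell is a minimal sketch in plain Python; the helper names `gini` and `gini_decrease` are hypothetical and are not part of scikit-learn. It recomputes the weighted Gini decrease for both candidate splits on the five training samples. Without intermediate rounding both values come out at about 0.2133; the figure 0.2136 in the text comes from rounding the child impurity to 0.444 first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gini impurity of a list of class labels (hypothetical helper, not a scikit-learn function)\n",
    "def gini(labels):\n",
    "    n = len(labels)\n",
    "    if n == 0:\n",
    "        return 0.0\n",
    "    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))\n",
    "\n",
    "\n",
    "# Impurity decrease obtained by splitting on X[feature] <= threshold (hypothetical helper)\n",
    "def gini_decrease(X, y, feature, threshold):\n",
    "    left = [label for row, label in zip(X, y) if row[feature] <= threshold]\n",
    "    right = [label for row, label in zip(X, y) if row[feature] > threshold]\n",
    "    weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(y)\n",
    "    return gini(y) - weighted\n",
    "\n",
    "\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 0, 0, 1, 1]\n",
    "\n",
    "drop_x1 = gini_decrease(X, y, feature=1, threshold=7)  # root split X[1] <= 7\n",
    "drop_x0 = gini_decrease(X, y, feature=0, threshold=6)  # root split X[0] <= 6\n",
    "print(round(drop_x1, 4), round(drop_x0, 4))"
   ]
  },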
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**2.回归决策树模型（DecisionTreeRegressor）**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4.5]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeRegressor\n",
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 2, 3, 4, 5]\n",
    "\n",
    "model = DecisionTreeRegressor(max_depth=2, random_state=0)\n",
    "model.fit(X, y)\n",
    "\n",
    "print(model.predict([[9, 9]]))"
   ]
  },
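  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the default squared-error criterion, a regression tree predicts the mean target of the training samples that fall into the same leaf as the query. The next cell is a rough sketch in plain Python of where the prediction 4.5 for `[9, 9]` comes from. It assumes the two splits shown in the stored visualization output of the following cell (`X[1] <= 5`, then `X[0] <= 6`); the helper name `same_leaf_as_9_9` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]\n",
    "y = [1, 2, 3, 4, 5]\n",
    "\n",
    "\n",
    "# Thresholds copied from the stored visualization; an assumption about this particular fit\n",
    "def same_leaf_as_9_9(row):\n",
    "    return row[1] > 5 and row[0] > 6\n",
    "\n",
    "\n",
    "# Training targets that share a leaf with the query sample [9, 9]\n",
    "leaf_targets = [t for row, t in zip(X, y) if same_leaf_as_9_9(row)]\n",
    "leaf_mean = sum(leaf_targets) / len(leaf_targets)  # the value the leaf predicts\n",
    "leaf_mse = sum((t - leaf_mean) ** 2 for t in leaf_targets) / len(leaf_targets)\n",
    "print(leaf_targets, leaf_mean, leaf_mse)"
   ]
  },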
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"398pt\" height=\"269pt\"\r\n",
       " viewBox=\"0.00 0.00 398.00 269.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 265)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-265 394,-265 394,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"236,-261 152,-261 152,-193 236,-193 236,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"194\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[1] &lt;= 5.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"194\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 2.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"194\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 5</text>\r\n",
       "<text text-anchor=\"middle\" x=\"194\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 3.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"185,-157 101,-157 101,-89 185,-89 185,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"143\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 2.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"143\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.25</text>\r\n",
       "<text text-anchor=\"middle\" x=\"143\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"143\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 1.5</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M177.442,-192.884C173.211,-184.422 168.604,-175.207 164.176,-166.352\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"167.252,-164.678 159.65,-157.299 160.991,-167.809 167.252,-164.678\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"151.744\" y=\"-177.316\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"289,-157 203,-157 203,-89 289,-89 289,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"246\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X[0] &lt;= 6.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"246\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.667</text>\r\n",
       "<text text-anchor=\"middle\" x=\"246\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"246\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.0</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>0&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M210.883,-192.884C215.197,-184.422 219.894,-175.207 224.409,-166.352\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"227.6,-167.798 229.024,-157.299 221.364,-164.619 227.6,-167.798\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"236.744\" y=\"-177.377\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"84,-53 0,-53 0,-0 84,-0 84,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"42\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"42\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"42\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 1.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M107.675,-88.9485C97.7023,-79.6175 86.8591,-69.4722 76.9108,-60.1641\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"79.2004,-57.5132 69.507,-53.2367 74.4178,-62.6247 79.2004,-57.5132\"/>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"186,-53 102,-53 102,-0 186,-0 186,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"144\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"144\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"144\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 2.0</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>1&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M143.35,-88.9485C143.437,-80.7153 143.531,-71.848 143.619,-63.4814\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"147.122,-63.2732 143.728,-53.2367 140.122,-63.1991 147.122,-63.2732\"/>\r\n",
       "</g>\r\n",
       "<!-- 5 -->\r\n",
       "<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"288,-53 204,-53 204,-0 288,-0 288,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"246\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"246\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"246\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 3.0</text>\r\n",
       "</g>\r\n",
       "<!-- 4&#45;&gt;5 -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>4&#45;&gt;5</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M246,-88.9485C246,-80.7153 246,-71.848 246,-63.4814\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"249.5,-63.2367 246,-53.2367 242.5,-63.2367 249.5,-63.2367\"/>\r\n",
       "</g>\r\n",
       "<!-- 6 -->\r\n",
       "<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"390,-53 306,-53 306,-0 390,-0 390,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"348\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.25</text>\r\n",
       "<text text-anchor=\"middle\" x=\"348\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"348\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4.5</text>\r\n",
       "</g>\r\n",
       "<!-- 4&#45;&gt;6 -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>4&#45;&gt;6</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M281.675,-88.9485C291.746,-79.6175 302.697,-69.4722 312.744,-60.1641\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"315.264,-62.6004 320.221,-53.2367 310.506,-57.4655 315.264,-62.6004\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x13c3de09dd8>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 生成可视化结果\n",
    "dot_data = export_graphviz(model, out_file=None)  # 回归决策树就没有class分类参数了\n",
    "graph = graphviz.Source(dot_data)\n",
    "graph"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
